# load packages
library(tidyverse)
library(brms)
library(tidybayes)
library(kableExtra)
library(broom)
library(foreach)
library(doMC)

# model fit used for every posterior draw below
fit202 <- read_rds("output/fits/model-01-structure-2-0-2.rds")

# analysis data; glimpse() prints a preview and hands the data back unchanged
df <- glimpse(read_rds("output/rescaled-data.rds"))

# back-of-the-envelope comparison

# create fake data set: every policy appears once as control (0) and once as
# treated (1), with its policy-level covariates attached
fake_df <- crossing(treat_indicator = c(0, 1),
                    policy = unique(df$policy)) %>%
  # `policy` is the only column the two tables share; naming it makes the join
  # key explicit and silences the "Joining with ..." message
  left_join(
    distinct(select(df, policy, stem, awareness, awareness_std, issue, category)),
    by = "policy"
  ) %>%
  # centered treatment indicator: -0.5 = control, 0.5 = treated
  mutate(treat_indicator_std = treat_indicator - 0.5) %>%
  glimpse()


# simulate quantities of interest: for every policy and posterior draw, the
# treatment effect is the expected outcome under treatment minus the expected
# outcome under control
qi_df <- fake_df %>%
  add_epred_draws(fit202) %>%
  ungroup() %>%
  # keep only the identifying columns; .row differs between the control and
  # treated rows of a policy, so it has to go before reshaping
  select(policy, stem, issue, category, awareness,
         treat_indicator, draw = .draw, .epred) %>%
  # one column per treatment arm (treat_0 / treat_1). pivot_wider() replaces
  # the superseded spread(); rows come out in order of first appearance rather
  # than sorted, which the aggregations downstream do not depend on
  pivot_wider(names_from = treat_indicator,
              values_from = .epred,
              names_prefix = "treat_") %>%
  mutate(treatment_effect = treat_1 - treat_0) %>%
  select(-treat_0, -treat_1) %>%
  # order stems by their average simulated effect
  mutate(stem = reorder(stem, treatment_effect)) %>%
  mutate(category = fct_recode(category,
                               `Economic Policy` = "Economic",
                               `Social Policy` = "Social")) %>%
  glimpse()

# summarize QIs: posterior mean, 5th/95th percentiles, and the posterior
# probability of a positive effect for each policy
qi_sum_df <- qi_df %>%
  group_by(policy, stem, issue, category, awareness) %>%
  summarize(post_avg = mean(treatment_effect),
            post95 = quantile(treatment_effect, 0.95),
            post05 = quantile(treatment_effect, 0.05),
            prob_gt0 = mean(treatment_effect > 0),
            .groups = "drop") %>%
  mutate(ht_result = case_when(prob_gt0 > 0.95 ~ "Strong Evidence",
                               prob_gt0 > 0.90 ~ "Moderate Evidence",
                               TRUE ~ "Weak or No Evidence"),
         # order the evidence labels by their average probability; this must
         # run on the whole column -- inside summarize() each group holds a
         # single row, so reorder() had nothing to order
         ht_result = reorder(ht_result, prob_gt0),
         # order stems by their posterior mean effect
         stem = reorder(stem, post_avg)) %>%
  glimpse()

glimpse(qi_sum_df)

# posterior-mean effect of each economic policy; every column gets an "econ_"
# prefix so it stays distinguishable once paired with a social policy
econ <- qi_sum_df %>%
  select(policy, stem, issue, category, awareness, post_avg) %>%
  filter(category == "Economic Policy") %>%
  rename_with(function(col) str_c("econ_", col)) %>%
  glimpse()

# same table for the social policies, with a "social_" prefix
social <- qi_sum_df %>%
  select(policy, stem, issue, category, awareness, post_avg) %>%
  filter(category == "Social Policy") %>%
  rename_with(function(col) str_c("social_", col)) %>%
  glimpse()

# every pairing of one economic policy with one social policy
combos <- crossing(econ_policy = econ$econ_policy,
                   social_policy = social$social_policy) %>%
  # explicit keys: each lookup table shares exactly one column with the grid
  left_join(econ, by = "econ_policy") %>%
  left_join(social, by = "social_policy") %>%
  mutate(
    # economic-minus-social difference in posterior-mean treatment effect
    diff = econ_post_avg - social_post_avg,
    # TRUE when the two awareness values differ by less than 0.05
    wi_5 = abs(econ_awareness - social_awareness) < 0.05,
    # the same flag as a descriptive label
    wi_5_fct = ifelse(wi_5,
                      "Awareness Within Five Percentage Points",
                      "Awareness Not Within Five Percentage Points"),
    # TRUE when the economic policy has the higher awareness
    gt = econ_awareness > social_awareness
  ) %>%
  glimpse()

# all pairings: count and share with a positive difference, then the mean,
# spread, and quartiles of the difference
with(combos, sum(diff > 0))
with(combos, mean(diff > 0))
with(combos, mean(diff))
with(combos, sd(diff))

with(combos, quantile(diff))

# pairings with comparable awareness only; note that the second line counts
# NEGATIVE differences, whereas the all-pairings count above is of positive ones
with(combos, sum(wi_5))
with(combos, sum(diff[wi_5] < 0))
with(combos, mean(diff[wi_5]))
with(combos, sd(diff[wi_5]))
with(combos, quantile(diff[wi_5]))

# pairings where the economic policy has the higher awareness
with(combos, sum(gt))
with(combos, mean(diff[gt] > 0))

# histogram of the economic-minus-social differences over all pairings, filled
# by whether the two policies have comparable awareness
# NOTE(review): fill maps the logical wi_5, so the legend entries read
# FALSE/TRUE; combos$wi_5_fct carries descriptive labels and is otherwise
# unused -- confirm which one the figure should show
ggplot(combos, aes(x = diff, fill = wi_5)) +
  scale_fill_brewer(type = "qual", palette = 2,
                    guide = guide_legend(
                      direction = "horizontal",
                      title.position = "top",
                      label.position = "bottom",
                      label.hjust = 1,
                      label.vjust = 1
                    )) +
  # 30 bins is what geom_histogram() falls back to; stating it avoids the
  # "pick better value" message without changing the figure
  geom_histogram(bins = 30) +
  theme_bw() +
  # left-align the legend title; replaces the deprecated legend.title.align
  theme(legend.position = "bottom",
        legend.title = element_text(hjust = 0)) +
  labs(x = "Difference in Treatment Effect",
       y = "Count",
       fill = "Comparability of Awareness",
       title = "Distribution of Estimates Among Possible Multi-Armed Studies")
# no plot argument: ggsave() writes the most recently created plot
ggsave("figs/figa3-multi-armed.tiff",
       width = 6, height = 4, scale = 1.0)

# average difference across all pairings
mean(combos$diff)

# find range of awareness shared by the categories other than Foreign Policy:
# hi is the smallest per-category maximum, lo the largest per-category minimum
mm <- df %>%
  filter(category != "Foreign Policy") %>%
  group_by(category) %>%
  summarize(lowest = min(awareness),
            highest = max(awareness)) %>%
  ungroup() %>%
  summarize(hi = min(highest),
            lo = max(lowest)) %>%
  glimpse()

# distinct observed awareness values inside the shared range [mm$lo, mm$hi]
awareness_overlap <- unique(
  df$awareness[between(df$awareness, mm$lo, mm$hi)]
)

# matching awareness levels: simulate studies that compare one random economic
# policy with one random social policy at the same awareness level, and fit
# the category-by-treatment regression to each simulated study
# NOTE(review): a 4-worker backend is registered, but the loop uses %do%,
# which runs sequentially. The plotting code further down reads `truth_tbl`,
# which is assigned inside the loop body -- presumably the reason %do% is used;
# confirm before switching to %dopar%.
registerDoMC(4)
sample_size <- rep(c(3000), 5) # needs to be divisible by 4
# each study spreads its sample evenly over the 4 category-by-treatment cells
stopifnot(sample_size %% 4 == 0)
t_tests <- foreach(i = seq_along(sample_size)) %do% {

  # index explicitly: sample(x, 1) would draw from 1:x if x were a single
  # number >= 1
  sampled_awareness <-
    awareness_overlap[sample.int(length(awareness_overlap), 1)]

  # 2 x 2 design (category by treatment) at the sampled awareness level
  design <- crossing(category = c("Economic", "Social"),
                     treat_indicator = 0:1,
                     awareness = sampled_awareness) %>%
    left_join(distinct(select(df, awareness, awareness_std)),
              by = "awareness") %>%
    mutate(treat_indicator_std = treat_indicator - 0.5)

  # expected outcomes in each cell, using only the category-level terms
  truth <- design %>%
    add_epred_draws(fit202,
                    re_formula = ~ (1 + treat_indicator_std | category)) %>%
    ungroup()

  # interaction coefficient computed from the expected outcomes
  # NOTE(review): overwritten on every iteration, so after the loop only the
  # last study's value remains although each study samples its own awareness
  truth_tbl <- tidy(lm(.epred ~ category*treat_indicator_std, data = truth)) %>%
    filter(term == "categorySocial:treat_indicator_std")

  # one simulated study: sample_size[i] / 4 respondents per cell, each cell
  # assigned a policy label passed to the model as a new level, with a single
  # posterior predictive draw per respondent
  study <- design %>%
    slice(rep(seq_len(n()), each = sample_size[i] / 4)) %>%
    mutate(policy = case_when(category == "Economic" ~ "Random Economic Policy",
                              category == "Social" ~ "Random Social Policy")) %>%
    add_predicted_draws(fit202,
                        allow_new_levels = TRUE,
                        sample_new_levels = "uncertainty",
                        ndraws = 1) %>%
    ungroup()

  # the fitted model is the value collected into t_tests
  lm(.prediction ~ category*treat_indicator_std, data = study)
}

# one row per simulated study: the estimated category-by-treatment interaction
sims <- tibble(t_test = t_tests) %>%
  mutate(tidy_t_test = map(t_test, broom::tidy),
         # row_number() stays valid for an empty list of studies; 1:n() would
         # yield c(1, 0) there
         study_id = row_number()) %>%
  unnest(cols = tidy_t_test) %>%
  # disabled coverage check, kept for reference:
  #mutate(capture = ifelse(conf.low < truth_tbl$estimate & conf.high > truth_tbl$estimate, "Captured", "Not Captured")) %>%
  filter(term == "categorySocial:treat_indicator_std") %>%
  glimpse()

# standard error of the interaction estimate, by study
sims %>%
  ggplot(aes(x = std.error, y = study_id)) +
  geom_point()

# estimate plus/minus one standard error for each study; the vertical line
# marks truth_tbl$estimate
# NOTE(review): truth_tbl is whatever the simulation loop assigned last --
# presumably the final study's value; confirm it is the intended reference
sims %>%
  ggplot(aes(xmin = estimate - std.error,
             xmax = estimate + std.error,
             y = study_id)) +
  geom_errorbarh() +
  geom_vline(xintercept = truth_tbl$estimate)
